import os
from PIL import Image
from tqdm import tqdm
import torch
import numpy as np
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
import matplotlib.pyplot as plt
import random
import warnings
# Suppress every UserWarning for the whole process (same filter entry as
# filterwarnings('ignore', category=UserWarning)).
warnings.simplefilter("ignore", UserWarning)


# Process-wide cache for the SAM model and its automatic mask generator.
# Both start empty and are populated lazily on the first auto_sam call.
sam_bak = mask_generator_bak = None

# (checkpoint path, device string) the cached model was built with; None
# means the cache was never built here (it may have been preloaded externally).
_sam_cache_key = None


def _get_mask_generator(device, checkpoint_path):
    """Return the cached SAM ViT-B mask generator, building it if needed.

    The model and generator are stored in the module globals ``sam_bak`` and
    ``mask_generator_bak`` so repeated calls do not reload the checkpoint.
    They are rebuilt when a later call asks for a different checkpoint or
    device than the one this function built them with.
    """
    global sam_bak, mask_generator_bak, _sam_cache_key
    key = (checkpoint_path, str(device))
    # Only treat the cache as stale if we know what it was built with; a
    # cache populated by other code (key still None) is reused as before.
    stale = _sam_cache_key is not None and _sam_cache_key != key
    if sam_bak is None or stale:
        sam = sam_model_registry["vit_b"](checkpoint=checkpoint_path).to(device)
        # Inference only: freeze every weight so no gradients are tracked.
        for parameter in sam.parameters():
            parameter.requires_grad = False
        generator = SamAutomaticMaskGenerator(points_per_side=32, model=sam)
        # Publish both objects only after both were built successfully, so a
        # failure cannot leave a model cached without its generator.
        sam_bak, mask_generator_bak = sam, generator
        _sam_cache_key = key
    return mask_generator_bak


def auto_sam(image_path="", tmp_path="", device='cuda:0',
             auto_sam_checkpoint_path='/home/zry/datasets/utils/auto-sam/checkpoints/sam_vit_b_01ec64.pth'):
    """Run SAM automatic mask generation on one image and save each mask's bbox.

    The SAM ViT-B model is loaded on first use and cached in module globals.
    Later calls reuse it unless they request a different checkpoint or device.

    Args:
        image_path: Path of the image to segment. The image is converted to
            RGB before being passed to SAM.
        tmp_path: Directory that receives the per-mask bbox files. It is
            created if missing. The default ``""`` writes into the current
            working directory.
        device: Torch device the model is placed on when it is (re)built.
        auto_sam_checkpoint_path: Path of the ``vit_b`` SAM checkpoint.

    Returns:
        The list of mask records returned by
        ``SamAutomaticMaskGenerator.generate``. The ``'bbox'`` entry of mask
        ``i`` (XYWH per SAM's documentation) is also written with
        ``torch.save`` to ``<tmp_path>/<basename(image_path)>.<i>.bbox``.
        Callers can therefore map the saved files back to the returned masks
        by index.
    """
    mask_generator = _get_mask_generator(device, auto_sam_checkpoint_path)

    # SAM expects an HWC uint8 RGB array. Converting handles grayscale,
    # palette and RGBA inputs, and the context manager closes the file.
    with Image.open(image_path) as image:
        image_array = np.array(image.convert("RGB"))

    # The number of masks differs from image to image.
    masks = mask_generator.generate(image_array)

    if tmp_path:
        # torch.save does not create missing parent directories.
        os.makedirs(tmp_path, exist_ok=True)
    image_label = os.path.basename(image_path)

    # Store each mask's position so it can be recovered later by index.
    for i, mask in tqdm(enumerate(masks), total=len(masks), desc='saving bbox', leave=False):
        output_bbox_path = os.path.join(tmp_path, f"{image_label}.{i}.bbox")
        torch.save(mask['bbox'], output_bbox_path)

    return masks